import numpy as np
import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev
from functorch.experimental import replace_all_batch_norm_modules_
import copy
from numba import jit

def load_model(model, path, map_location=None):
	'''
	Load the weights stored under the 'state_dict' key of a checkpoint into model.

	model: module whose load_state_dict method receives the weights
	path: checkpoint location handed to torch.load
	map_location: optional device remapping forwarded to torch.load. The
		default None keeps torch.load's own behaviour; passing e.g. 'cpu'
		lets a checkpoint written from a GPU be read on a machine without one

	Returns the same model object after its state has been updated in place.
	A checkpoint dict without a 'state_dict' entry raises KeyError.
	'''
	# NOTE: torch.load unpickles the file, so only open checkpoints from trusted sources.
	checkpoint = torch.load(path, map_location=map_location)
	model.load_state_dict(checkpoint['state_dict'])
	return model


def warp_model(network, device='cuda'):
	'''
	Build a functional (stateless) version of network via make_functional.

	network: module to convert; it is deep-copied first, so the caller's
		module is neither moved nor modified
	device: where the copy is placed before conversion. Defaults to 'cuda',
		the previously hard-coded target; pass 'cpu' to work without a GPU

	Returns the (net, params) pair produced by make_functional; elsewhere in
	this file the pair is used as net(params, batch).
	'''
	network_copy = copy.deepcopy(network).to(device)
	# Applied in place on the copy before conversion; going by its name it
	# rewrites the batch-norm layers (see the functorch docs for the details).
	replace_all_batch_norm_modules_(network_copy)
	net, params = make_functional(network_copy)
	return net, params


def empirical_ntk_jacobian_contraction(net, params, x1, x2, compute='full'):
	'''
	Empirical NTK between two batches, obtained by contracting per-sample Jacobians.

	net, params: functional model and its parameters, invoked as net(params, batch)
	x1, x2: input batches whose leading dimension indexes samples (N and M below)
	compute: which part of the kernel to return. With O the length of the
		per-sample output (assumed to be a vector):
		'full'     -> shape (N, M, O, O), every pair of output coordinates
		'trace'    -> shape (N, M), summed over matching output coordinates
		'diagonal' -> shape (N, M, O), matching output coordinates only

	Raises ValueError for any other value of compute.
	'''
	einsum_exprs = {
		'full': 'Naf,Mbf->NMab',
		'trace': 'Naf,Maf->NM',
		'diagonal': 'Naf,Maf->NMa',
	}
	# Validate before the expensive Jacobian passes, and with an explicit raise
	# so the check is not stripped under `python -O` the way an assert is.
	if compute not in einsum_exprs:
		raise ValueError(f"compute must be 'full', 'trace' or 'diagonal', got {compute!r}")
	einsum_expr = einsum_exprs[compute]

	def single_f(params, x):
		# Evaluate one sample by adding, then removing, a batch dimension.
		return net(params, x.unsqueeze(0)).squeeze(0)

	# Compute J(x1): one Jacobian per parameter tensor, flattened from dim 2 on
	jac1 = vmap(jacrev(single_f), (None, 0))(params, x1)
	jac1 = [j.flatten(2) for j in jac1]

	# Compute J(x2)
	jac2 = vmap(jacrev(single_f), (None, 0))(params, x2)
	jac2 = [j.flatten(2) for j in jac2]

	# Compute J(x1) @ J(x2).T per parameter tensor, then add the contributions
	result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
	result = result.sum(0)
	return result

def empirical_ntk_jacobian_contraction_symmetric(net, params, x):
	'''
	Per-sample Jacobians of the network output with respect to each parameter tensor.

	Despite the name, no contraction is performed here; only the Jacobians
	are returned, so the caller can contract them as needed.

	net, params: functional model and its parameters, invoked as net(params, batch)
	x: input batch; the leading dimension indexes samples

	Returns a list with one tensor per parameter, each flattened from
	dimension 2 onward, detached from autograd and moved to the CPU.
	'''
	def per_sample_output(weights, sample):
		batch_of_one = sample.unsqueeze(0)
		return net(weights, batch_of_one).squeeze(0)

	# Differentiate w.r.t. the parameters, mapping over samples only.
	batched_jacobian = vmap(jacrev(per_sample_output), in_dims=(None, 0))

	flat_jacobians = []
	for block in batched_jacobian(params, x):
		flat_jacobians.append(block.flatten(2).detach().cpu())
	return flat_jacobians

def empirical_ntk_jacobian_contraction_symmetric_savememory(net, params, x):
	'''
	Per-sample Jacobians for each parameter tensor, stored in bfloat16.

	Same computation as the non-"savememory" variant, except that every
	Jacobian is cast to bfloat16 before being moved to the CPU.

	net, params: functional model and its parameters, invoked as net(params, batch)
	x: input batch; the leading dimension indexes samples

	Returns a list with one detached bfloat16 CPU tensor per parameter, each
	flattened from dimension 2 onward.
	'''
	def per_sample_output(weights, sample):
		batch_of_one = sample.unsqueeze(0)
		return net(weights, batch_of_one).squeeze(0)

	# Differentiate w.r.t. the parameters, mapping over samples only.
	batched_jacobian = vmap(jacrev(per_sample_output), in_dims=(None, 0))

	compact_jacobians = []
	for block in batched_jacobian(params, x):
		flattened = block.flatten(2).detach()
		# Down-cast before the transfer, as the original chain does.
		compact_jacobians.append(flattened.bfloat16().cpu())
	return compact_jacobians

def calculate_class_average_matrix(kernel):
	'''
	Average the class-diagonal slices kernel[:, :, i, i] over all classes i.

	kernel: 4-D tensor; the first two dimensions index data points and the
		last two index classes (the class count is taken from dimension 2)

	Returns a tensor with the kernel's first two dimensions, in the default
	dtype and on the kernel's device.
	'''
	class_num = kernel.size(2)
	# Take shape and device from the kernel: a fixed (N, N) CPU accumulator
	# fails for rectangular (N, M) kernels and for kernels living on a GPU.
	matrix = torch.zeros(kernel.shape[:2], device=kernel.device)
	for i in range(class_num):
		matrix += kernel[:, :, i, i]
	matrix /= class_num
	return matrix

def calculate_class_ralative_matric(kernel, class_i, class_j):
	'''
	Extract the slice of the kernel for one pair of classes.

	class_i for the training data's label
	class_j for the test data's label

	Returns a fresh tensor equal to kernel[:, :, class_i, class_j], in the
	default dtype and on the kernel's device; it shares no storage with kernel.
	'''
	# Take shape and device from the kernel: a fixed (N, N) CPU buffer fails
	# for rectangular (N, M) kernels and for kernels living on a GPU.
	matrix = torch.zeros(kernel.shape[:2], device=kernel.device)
	matrix += kernel[:, :, class_i, class_j]
	return matrix

@jit(forceobj=True)
def cal_kernel_distance(k1, k2):
	'''
	One minus the normalised Frobenius inner product of two matrices:

		distance = 1 - <k1, k2>_F / (||k1||_F * ||k2||_F)

	k1, k2: 2-D arrays, normally of identical shape
	Returns 1. when either matrix has a zero Frobenius norm.
	'''
	a = np.asarray(k1)
	b = np.asarray(k2)
	if a.ndim == 2 and a.shape == b.shape:
		# trace(a.T @ b) equals the sum of element-wise products. Computing it
		# directly avoids the cubic-cost matrix product and its temporary;
		# the value can differ from the matmul form only by float rounding.
		inner = np.sum(a * b)
	else:
		# Any other shape combination keeps the matrix-product formulation.
		inner = np.trace(np.matmul(a.T, b))
	norm1 = np.linalg.norm(a, ord='fro')
	norm2 = np.linalg.norm(b, ord='fro')
	if norm1 == 0.0 or norm2 == 0.0:
		return 1.
	distance = inner / (norm1 * norm2)

	return 1. - distance

@jit(forceobj=True)
def erank(w):
	'''
	Exponential of the Shannon entropy of w after normalising it to sum to one.

	w: array of weights (presumably a non-negative spectrum such as singular
		values; confirm against the callers)

	Entries that are exactly zero contribute nothing, following the
	convention 0 * log(0) = 0. Negative entries and an all-zero w still
	produce nan, as before.
	'''
	p = w / (w.sum())
	# log(0) * 0 evaluates to nan and would poison the whole sum, so exact
	# zeros are dropped before taking the logarithm.
	p = p[p != 0]
	h = (- np.log(p) * p).sum()
	r = np.exp(h)
	return r

@jit(forceobj=True)
def cal_ksm_distance(k1, k2):
	'''
	Normalised alignment between a matrix k1 and the Gram matrix k2 @ k2.T:

		value = trace(k2.T @ k1 @ k2) / (||k1||_F * ||k2 @ k2.T||_F)

	Despite "distance" in the name, the ratio itself is returned, not one
	minus it.

	Returns 0.0 when the denominator is zero. That matches the zero alignment
	implied by cal_kernel_distance, which returns a distance of 1. for
	zero-norm input, and avoids a nan/inf from dividing by zero.
	'''
	# np.trace replaces the method borrowed from the deprecated np.matrix class.
	trace = np.trace(np.matmul(np.matmul(k2.T, k1), k2))
	denominator = np.linalg.norm(k1, ord='fro') * np.linalg.norm(np.matmul(k2, k2.T), ord='fro')
	if denominator == 0.0:
		return 0.0
	distance = trace / denominator

	return distance

@jit(forceobj=True)
def ksm(kernel, class_i, class_j, label_map):
	'''
	Ratio between the alignment of the class_i kernel with the class_j label
	matrix and the mean alignment of every class kernel with that matrix.

	kernel: array indexed as kernel[sample, sample, class, class]
	class_i: class whose kernel slice kernel[:, :, class_i, class_i] is scored
	class_j: class whose labels define the target matrix
	label_map: per-sample labels, compared against class_j

	Alignment is 1 - cal_kernel_distance(...). Returns 0.0 when the mean
	alignment is exactly zero.
	'''
	n_samples = kernel.shape[0]
	n_classes = kernel.shape[-1]

	# Column indicator of the samples labelled class_j, and the target
	# matrix it induces (1 where both samples belong to class_j).
	indicator = np.zeros((n_samples, 1))
	hits = np.where(label_map == class_j)
	indicator[hits, 0] = 1.0
	target = np.matmul(indicator, indicator.T)

	align_selected = 1.0 - cal_kernel_distance(kernel[:, :, class_i, class_i], target)

	# Explicit loop on purpose: this function is compiled by numba.
	align_total = 0.0
	for c in range(n_classes):
		align_total += (1.0 - cal_kernel_distance(kernel[:, :, c, c], target))
	align_mean = align_total / n_classes

	if align_mean == 0.0:
		return 0.0
	return align_selected / align_mean

@jit(forceobj=True)
def ksm_generalization(kernel, class_i, class_j, label_map_trainset, label_map_testset):
	'''
	Train/test variant of ksm: ratio between the alignment of the class_i
	kernel with the class_j test-by-train label matrix and the mean alignment
	of every class kernel with that matrix.

	kernel: array indexed as kernel[row, column, class, class]. Rows are
		matched against the test labels and columns against the train
		labels, because the target is built as onehot_test @ onehot_train.T
	class_i: class whose kernel slice kernel[:, :, class_i, class_i] is scored
	class_j: class whose labels define the target matrix
	label_map_trainset, label_map_testset: per-sample labels of the two sets

	Returns 0.0 when the mean alignment is exactly zero.
	'''
	test_num = kernel.shape[0]
	# Size the train indicator by the kernel's column count. Using the row
	# count for both sets breaks whenever the two sets differ in size.
	train_num = kernel.shape[1]
	class_num = kernel.shape[-1]

	onehot_j_trainset = np.zeros((train_num, 1))
	mask_trainset = np.where(label_map_trainset == class_j)
	onehot_j_trainset[mask_trainset, 0] = 1.0

	onehot_j_testset = np.zeros((test_num, 1))
	mask_testset = np.where(label_map_testset == class_j)
	onehot_j_testset[mask_testset, 0] = 1.0

	# 1 where the test sample (row) and the train sample (column) are both class_j
	class_matrix = np.matmul(onehot_j_testset, onehot_j_trainset.T)
	dis_i = 1.0 - cal_kernel_distance(kernel[:, :, class_i, class_i], class_matrix)
	dis_other = 0.0
	for i in range(class_num):
		dis_other += (1.0 - cal_kernel_distance(kernel[:, :, i, i], class_matrix))
	dis_other /= class_num
	if dis_other == 0.0:
		return 0.0
	ks = dis_i / dis_other
	return ks


def subspace_distance(baise1, baise2):
	'''
	Placeholder with no implementation yet.

	baise1, baise2: currently ignored
	Always returns None.
	'''
	return None


